// Disjoint-set (union-find) structure over the integers 0 .. n-1.
// Sets are merged by size and lookups compress the path to the root.
class UnionFind {
public:
    // Creates n singleton sets: every element is its own root with size 1.
    // `explicit` prevents an int from silently converting into a UnionFind.
    explicit UnionFind(int n) : parent_(n), size_(n, 1) {
        std::iota(parent_.begin(), parent_.end(), 0);
    }

    // Merges the sets containing a and b.
    // Returns false when both already share a root (nothing changes),
    // true when two distinct sets were joined.
    bool unite(int a, int b) {
        const int rootA = find(a);
        const int rootB = find(b);
        if (rootA == rootB) {
            return false;
        }
        // Hang the smaller tree under the larger one. On equal sizes the
        // root of b's set becomes the representative.
        if (size_[rootA] > size_[rootB]) {
            parent_[rootB] = rootA;
            size_[rootA] += size_[rootB];
        } else {
            parent_[rootA] = rootB;
            size_[rootB] += size_[rootA];
        }
        return true;
    }

    // Returns the representative (root) of the set containing x.
    // Every node visited on the way is re-pointed directly at the root.
    int find(int x) {
        if (parent_[x] != x) {
            parent_[x] = find(parent_[x]);
        }
        return parent_[x];
    }

    // Returns the size stored for index `root`. Sizes are only maintained
    // for set representatives, so pass a value obtained from find().
    int getSize(int root) const {
        return size_[root];
    }

private:
    std::vector<int> parent_;  // parent_[i] == i marks i as a root
    std::vector<int> size_;    // element count of the set, valid at roots only
};

class Solution {
public:
    // Chooses the node from `initial` whose total size of exclusively-touched
    // components is largest, and returns its index.
    //
    // graph   - adjacency matrix; graph[i][j] non-zero means i and j are linked.
    // initial - indices of the initially marked nodes. The per-component
    //           counting below assumes these are distinct, in-range indices.
    //
    // Steps performed:
    //   1. Union all non-initial nodes that are linked to each other, which
    //      yields the components of the graph with the initial nodes removed.
    //   2. For every initial node, collect the component roots it is directly
    //      linked to, and count per root how many initial nodes touch it.
    //   3. Score every initial node with the summed size of the components
    //      that only it touches; the highest score wins, ties go to the
    //      smaller index. Returns 0 when `initial` is empty.
    int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
        const int n = static_cast<int>(graph.size());

        // Heap-backed containers instead of variable-length arrays: VLAs are a
        // compiler extension and are not portable for non-trivial element types.
        std::vector<char> isInitial(n, 0);
        for (int node : initial) {
            isInitial[node] = 1;
        }

        // Step 1: connect only the nodes that are not in `initial`.
        UnionFind uf(n);
        for (int i = 0; i < n; ++i) {
            if (isInitial[i]) {
                continue;
            }
            for (int j = i + 1; j < n; ++j) {
                if (graph[i][j] && !isInitial[j]) {
                    uf.unite(i, j);
                }
            }
        }

        // Step 2: touchedRoots[v] = roots of components adjacent to initial
        // node v; touchCount[r] = number of initial nodes adjacent to root r.
        std::vector<std::unordered_set<int>> touchedRoots(n);
        std::vector<int> touchCount(n, 0);
        for (int src : initial) {
            for (int j = 0; j < n; ++j) {
                if (!isInitial[j] && graph[src][j]) {
                    touchedRoots[src].insert(uf.find(j));
                }
            }
            for (int root : touchedRoots[src]) {
                ++touchCount[root];
            }
        }

        // Step 3: best starts at -1 so the first initial node always wins once.
        int ans = 0;
        int best = -1;
        for (int src : initial) {
            int gain = 0;
            for (int root : touchedRoots[src]) {
                // Only components touched by this initial node alone count.
                if (touchCount[root] == 1) {
                    gain += uf.getSize(root);
                }
            }
            if (gain > best || (gain == best && src < ans)) {
                ans = src;
                best = gain;
            }
        }
        return ans;
    }
};